{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.7.3"
    },
    "colab": {
      "name": "cifar-cnn.ipynb",
      "provenance": [],
      "include_colab_link": true
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/github/lukas/ml-class/blob/master/examples/keras-cifar/cifar-cnn.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FDUz5UGUML6R",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Install wandb, which the training cell imports. %pip (rather than !pip)\n",
        "# installs into the environment of the running kernel.\n",
        "%pip install -q wandb"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "f7mlk_kMDpuB",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 373
        },
        "outputId": "4487c370-97ce-449b-b25d-ec213c330870"
      },
      "source": [
        "import tensorflow as tf\n",
        "import wandb\n",
        "\n",
        "# Set hyper-parameters. They are stored on wandb.config so the values a run\n",
        "# trains with are the same ones that get logged.\n",
        "wandb.init()\n",
        "config = wandb.config\n",
        "config.batch_size = 128\n",
        "config.epochs = 10\n",
        "config.learn_rate = 0.001\n",
        "config.dropout = 0.3\n",
        "config.dense_layer_nodes = 128\n",
        "\n",
        "# Load data\n",
        "class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',\n",
        "               'dog', 'frog', 'horse', 'ship', 'truck']\n",
        "num_classes = len(class_names)\n",
        "\n",
        "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
        "# Cast images to float32 and divide by 255 before feeding them to the network.\n",
        "X_train = X_train.astype('float32') / 255.\n",
        "X_test = X_test.astype('float32') / 255.\n",
        "\n",
        "# Convert class vectors to binary class matrices, the label format that the\n",
        "# categorical_crossentropy loss below expects.\n",
        "y_train = tf.keras.utils.to_categorical(y_train, num_classes)\n",
        "y_test = tf.keras.utils.to_categorical(y_test, num_classes)\n",
        "\n",
        "# Define model: a single Conv2D + MaxPooling2D block, then a dense classifier.\n",
        "model = tf.keras.models.Sequential()\n",
        "model.add(tf.keras.layers.Conv2D(32, (3, 3), padding='same',\n",
        "                                 input_shape=X_train.shape[1:], activation='relu'))\n",
        "model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))\n",
        "model.add(tf.keras.layers.Dropout(config.dropout))\n",
        "\n",
        "model.add(tf.keras.layers.Flatten())\n",
        "model.add(tf.keras.layers.Dense(config.dense_layer_nodes, activation='relu'))\n",
        "model.add(tf.keras.layers.Dropout(config.dropout))\n",
        "model.add(tf.keras.layers.Dense(num_classes, activation='softmax'))\n",
        "model.compile(loss='categorical_crossentropy',\n",
        "              optimizer=tf.keras.optimizers.Adam(config.learn_rate),\n",
        "              metrics=['accuracy'])\n",
        "# log the number of total parameters\n",
        "config.total_params = model.count_params()\n",
        "print(\"Total params: \", config.total_params)\n",
        "\n",
        "# Train using the epochs and batch size stored in config instead of repeating\n",
        "# the literals, so changing a hyper-parameter above actually changes training.\n",
        "# Note that the test split doubles as the validation set here.\n",
        "model.fit(X_train, y_train,\n",
        "          epochs=config.epochs, batch_size=config.batch_size,\n",
        "          validation_data=(X_test, y_test),\n",
        "          callbacks=[wandb.keras.WandbCallback(data_type=\"image\", labels=class_names, save_model=False)])"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}